Avoid crash on a second epoch over a skip/take IterableDataset#4081

Open
robbiebusinessacc wants to merge 1 commit into huggingface:main from robbiebusinessacc:contrib/iterable-dataset-epoch-source-shuffling

What this fixes

Iterating a prepared HF datasets.IterableDataset for more than one epoch crashes
with DataSourcesShufflingDisallowed when the dataset was built via .skip() /
.take() (#4080).

DataLoaderShard.__iter__ and DataLoaderDispatcher.__iter__ call
self.set_epoch(self.iteration) at the start of every pass, forwarding a nonzero
epoch to the wrapped dataset. On the second epoch the dataset tries to reshuffle its
data sources at iteration time, which .skip()/.take() iterables forbid, so it
raises DataSourcesShufflingDisallowed.

The fix

Catch DataSourcesShufflingDisallowed while iterating and reset the dataset's epoch
to 0, then re-create the iterator. These datasets can't reshuffle their sources
between epochs anyway, so this lets them iterate across multiple epochs without
changing behaviour for datasets that do support per-epoch reshuffling. The exception
is imported defensively so datasets stays optional.
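
As a rough illustration of the catch-and-reset idea, here is a minimal, self-contained sketch. It does not use accelerate or datasets: `ToySkipDataset` and `ToyLoader` are hypothetical stand-ins, and the `DataSourcesShufflingDisallowed` class below is a local placeholder for the real exception, not an import. The sketch also assumes the exception is raised before the first item is produced, and re-raises otherwise so that no items are yielded twice.

```python
class DataSourcesShufflingDisallowed(Exception):
    """Local placeholder for the exception named in this PR (not imported)."""


class ToySkipDataset:
    """Hypothetical dataset that refuses to iterate with a nonzero epoch."""

    def __init__(self, items):
        self.items = list(items)
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Generator: the check runs lazily, on the first next() call.
        if self.epoch != 0:
            raise DataSourcesShufflingDisallowed("cannot reshuffle data sources")
        yield from self.items


class ToyLoader:
    """Hypothetical stand-in for a prepared dataloader (not accelerate's class)."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.iteration = 0

    def __iter__(self):
        # Forward the pass counter as the epoch, as the description above says.
        self.dataset.set_epoch(self.iteration)
        yielded = False
        try:
            for item in self.dataset:
                yielded = True
                yield item
        except DataSourcesShufflingDisallowed:
            if yielded:
                # Assumption violated: failing mid-pass, so do not restart.
                raise
            # Fall back to epoch 0 and re-create the iterator.
            self.dataset.set_epoch(0)
            yield from self.dataset
        self.iteration += 1


loader = ToyLoader(ToySkipDataset(range(10, 13)))
first_epoch = list(loader)
second_epoch = list(loader)  # would raise without the fallback
```

In this toy version both passes yield the same items in the same order, which matches the claim that these datasets cannot reshuffle their sources between epochs anyway.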

Repro (from the issue)

  import datasets, accelerate, torch.utils.data
  ds = datasets.Dataset.from_dict({"a": list(range(20))}).to_iterable_dataset().skip(10)
  dl = accelerate.Accelerator().prepare_data_loader(torch.utils.data.DataLoader(ds, batch_size=2))
  for b in dl: pass
  for b in dl: pass   # raised DataSourcesShufflingDisallowed before this fix

Tests

Adds test_iterable_dataset_blocked_source_shuffling_multiple_epochs, covering both
the dispatch and non-dispatch paths. It fails on main and passes with the fix; the
rest of tests/test_data_loader.py is unaffected.

Fixes #4080

`DataLoaderShard`/`DataLoaderDispatcher` call `set_epoch(self.iteration)` at the
start of each pass, forwarding a nonzero epoch to a wrapped HF
`datasets.IterableDataset`. For datasets built via `.skip()`/`.take()` this makes
iteration raise `DataSourcesShufflingDisallowed` on the second epoch, since those
datasets forbid reshuffling their data sources between epochs.

Catch the exception while iterating and reset the dataset epoch to 0 so such
datasets can still be iterated across multiple epochs; per-epoch source
reshuffling stays available for datasets that support it. Adds a regression test
covering both the dispatch and non-dispatch paths.

Fixes huggingface#4080

Development

Successfully merging this pull request may close these issues.

[bug report] DataSourcesShufflingDisallowed when training using split datasets.IterableDataset
